// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest
import OpenSpiel

/// Convergence tests for exploitability descent on two-player poker games.
///
/// Each test builds a populated `InformationStateCache` for the game, then
/// (when enabled) performs 11 update steps. Each step:
/// 1. Differentiates the negated `ActionValueCalculator.loss` with respect to
///    a `TensorFlowLogitTable`.
/// 2. Records the calculator's `nashConv` after every loss evaluation.
/// 3. Moves the logit table along the resulting gradient.
///
/// The recorded NashConv sequence is then compared against reference values.
///
/// NOTE: The convergence loops are currently compiled out with `#if false`
/// (see TF-985 below). Until they are re-enabled, these tests only exercise
/// game/cache construction, `populate()`, and construction of the logit table
/// and action-value calculator.
final class ExploitabilityDescentTests: XCTestCase {
  /// Checks a recorded NashConv sequence against reference values.
  ///
  /// The lengths are asserted to be equal first. A bare `zip` silently ignores
  /// trailing elements, so a run that recorded too few (or zero) values would
  /// otherwise pass without checking anything.
  ///
  /// - Parameters:
  ///   - actual: NashConv values recorded during the run, in order.
  ///   - expected: Reference NashConv values, in the same order.
  ///   - accuracy: Maximum allowed absolute difference per element.
  ///   - file: Call-site file used when reporting failures.
  ///   - line: Call-site line used when reporting failures.
  private func assertNashConvs(
    _ actual: [Double],
    match expected: [Double],
    accuracy: Double = 1e-5,
    file: StaticString = #file,
    line: UInt = #line
  ) {
    XCTAssertEqual(
      actual.count, expected.count,
      "number of recorded NashConv values differs from the reference",
      file: file, line: line)
    for (index, (value, reference)) in zip(actual, expected).enumerated() {
      XCTAssertEqual(
        value, reference, accuracy: accuracy,
        "NashConv mismatch at iteration \(index)",
        file: file, line: line)
    }
  }

  /// Exploitability descent on 2-player Kuhn poker should reduce NashConv
  /// along the reference trajectory.
  func testKuhnPokerConvergence() throws {
    let game = KuhnPoker(playerCount: 2)
    let informationStateCache = InformationStateCache(game)
    informationStateCache.populate()
    var logitTable = TensorFlowLogitTable(informationStateCache)
    var actionValueCalculator = ActionValueCalculator(informationStateCache)
    var nashConvs: [Double] = []
#if false
    // NOTE: Test is disabled due to a non-differentiability error regarding
    // `inout` arguments: the `self` argument to the mutating function
    // `ActionValueCalculator.loss` is marked as active by activity analysis.
    // This behavior is expected and currently there is no known workaround:
    //
    //     note: cannot differentiate through `inout` arguments
    //             let loss = actionValueCalculator.loss(for: policy)
    //                                              ^
    // https://bugs.swift.org/browse/TF-985 tracks this issue.
    for _ in 0...10 {
      let gradients = gradient(at: logitTable) {
        (logitTable: TensorFlowLogitTable) -> Double in
        let policy = TensorFlowTabularPolicy(logitTable)
        let loss = actionValueCalculator.loss(for: policy)
        nashConvs = nashConvs + [actionValueCalculator.nashConv]
        return -loss
      }
      logitTable.move(along: gradients)
    }
    assertNashConvs(nashConvs, match: [
      0.91666666666666652, 0.67893004801213452, 0.48109148836354743,
      0.40061420923255808, 0.36617242161468722, 0.33676996443499557,
      0.30925081512398128, 0.28827843035940964, 0.26830042206858751,
      0.24418597846799289, 0.22168699344791482
    ])
#endif
  }

  /// Exploitability descent on 2-player Leduc poker should reduce NashConv
  /// along the reference trajectory.
  func testLeducPokerConvergence() throws {
    let game = LeducPoker(playerCount: 2)
    let informationStateCache = InformationStateCache(game)
    informationStateCache.populate()
    var logitTable = TensorFlowLogitTable(informationStateCache)
    var actionValueCalculator = ActionValueCalculator(informationStateCache)
    var nashConvs: [Double] = []
#if false
    // NOTE: Test is disabled due to a non-differentiability error regarding
    // `inout` arguments: the `self` argument to the mutating function
    // `ActionValueCalculator.loss` is marked as active by activity analysis.
    // This behavior is expected and currently there is no known workaround:
    //
    //     note: cannot differentiate through `inout` arguments
    //             let loss = actionValueCalculator.loss(for: policy)
    //                                              ^
    // https://bugs.swift.org/browse/TF-985 tracks this issue.
    for _ in 0...10 {
      let gradients = gradient(at: logitTable) {
        (logitTable: TensorFlowLogitTable) -> Double in
        let policy = TensorFlowTabularPolicy(logitTable)
        let loss = actionValueCalculator.loss(for: policy)
        nashConvs = nashConvs + [actionValueCalculator.nashConv]
        return -loss
      }
      logitTable.move(along: gradients)
    }
    assertNashConvs(nashConvs, match: [
      4.7472224, 4.3147216, 3.9900389, 3.7576618, 3.5771275, 3.4414644,
      3.3272073, 3.1898201, 3.1089299, 3.0108435, 2.8992782
    ])
#endif
  }
}

extension ExploitabilityDescentTests {
  /// Name/method pairs for every test in this case, consumed by test
  /// manifests on platforms without Objective-C runtime test discovery.
  static var allTests: [(String, (ExploitabilityDescentTests) -> () throws -> Void)] = [
    ("testKuhnPokerConvergence", testKuhnPokerConvergence),
    ("testLeducPokerConvergence", testLeducPokerConvergence),
  ]
}
